
package org.hhl.hbaseETL.hbase

import java.util.UUID

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.permission.FsPermission
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hbase._
import org.apache.hadoop.hbase.classification.InterfaceAudience
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapreduce._
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.partition.TotalOrderPartitioner
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.{SerializableWritable, SparkContext}

import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag



/**
  * HBaseContext is a façade for HBase operations
  * like bulk put, get, delete, scan and HFile bulk load.
  *
  * HBaseContext takes responsibility for disseminating the configuration
  * information to the workers (through Spark broadcast variables and,
  * optionally, a temporary configuration file on HDFS) and for managing the
  * life cycle of Connections.
  *
  * @param sc               SparkContext used to create broadcasts and RDDs
  * @param config           HBase/Hadoop configuration to distribute
  * @param tmpHdfsConfgFile optional HDFS path. When non-null, the constructor writes
  *                         `config` there (unless the file already exists), and
  *                         `getConf` reads it back whenever no in-memory configuration
  *                         is available (the field holding it is `@transient`)
  */
@InterfaceAudience.Public
class HBaseContext(@transient sc: SparkContext,
                   @transient val config: Configuration,
                   val tmpHdfsConfgFile: String = null)
  extends Serializable with Logging {

  @transient var credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
  @transient var tmpHdfsConfiguration: Configuration = config
  @transient var appliedCredentials = false
  @transient val job = Job.getInstance(config)
  TableMapReduceUtil.initCredentials(job)
  val broadcastedConf = sc.broadcast(new SerializableWritable(config))
  val credentialsConf = sc.broadcast(new SerializableWritable(job.getCredentials))

  if (tmpHdfsConfgFile != null && config != null) {
    // newInstance returns an uncached FileSystem, so it must be closed explicitly.
    val fs = FileSystem.newInstance(config)
    try {
      val tmpPath = new Path(tmpHdfsConfgFile)
      if (!fs.exists(tmpPath)) {
        val outputStream = fs.create(tmpPath)
        try {
          config.write(outputStream)
        } finally {
          outputStream.close()
        }
      } else {
        logWarning("tmpHdfsConfigDir " + tmpHdfsConfgFile + " exist!!")
      }
    } finally {
      fs.close()
    }
  }

  // Publish this instance only once the constructor has finished successfully.
  LatestHBaseContextCache.latest = this

  /**
    * A simple enrichment of the traditional Spark RDD foreachPartition.
    * This function differs from the original in that it offers the
    * developer access to an already connected Connection object.
    *
    * Note: Do not close the Connection object. All Connection
    * management is handled outside this method.
    *
    * @param rdd Original RDD with data to iterate over
    * @param f   Function given an iterator over the partition's values
    *            and a Connection object to interact with HBase
    */
  def foreachPartition[T](rdd: RDD[T],
                          f: (Iterator[T], Connection) => Unit): Unit = {
    rdd.foreachPartition(
      it => hbaseForeachPartition(broadcastedConf, it, f))
  }

  /**
    * A simple enrichment of the traditional Spark RDD mapPartitions.
    * This function differs from the original in that it offers the
    * developer access to an already connected Connection object.
    *
    * Note: Do not close the Connection object. All Connection
    * management is handled outside this method.
    *
    * @param rdd Original RDD with data to iterate over
    * @param mp  Function given an iterator over the partition's values
    *            and a Connection object to interact with HBase
    * @return A new RDD generated by the user-defined function,
    *         just like a normal mapPartitions
    */
  def mapPartitions[T, R: ClassTag](rdd: RDD[T],
                                    mp: (Iterator[T], Connection) => Iterator[R]): RDD[R] = {
    rdd.mapPartitions[R](it => hbaseMapPartition[T, R](broadcastedConf, it, mp))
  }

  /**
    * A simple abstraction over the HBaseContext.foreachPartition method.
    *
    * It lets a user turn each element of an RDD into a Put and send it to HBase
    * through a BufferedMutator. The complexity of managing the Connection is
    * removed from the developer.
    *
    * @param rdd       Original RDD with data to iterate over
    * @param tableName The name of the table to put into
    * @param f         Function to convert a value in the RDD to a HBase Put
    */
  def bulkPut[T](rdd: RDD[T], tableName: TableName, f: (T) => Put): Unit = {
    // Capture only the raw name bytes so the closure does not depend on TableName.
    val tName = tableName.getName
    rdd.foreachPartition(
      it => hbaseForeachPartition[T](
        broadcastedConf,
        it,
        (iterator, connection) => {
          val mutator = connection.getBufferedMutator(TableName.valueOf(tName))
          try {
            iterator.foreach(record => mutator.mutate(f(record)))
            mutator.flush()
          } finally {
            // Always release the mutator, even if a mutation fails.
            mutator.close()
          }
        }))
  }

  /**
    * Merges the current user's credentials and the broadcast job credentials
    * into the current UserGroupInformation, once per deserialized instance.
    *
    * `appliedCredentials` is `@transient`, so every deserialized copy of this
    * context applies the credentials once.
    *
    * @tparam T unused; kept for source compatibility with existing callers
    */
  def applyCreds[T](): Unit = {
    credentials = SparkHadoopUtil.get.getCurrentUserCredentials()

    logDebug("appliedCredentials:" + appliedCredentials + ",credentials:" + credentials)

    if (!appliedCredentials && credentials != null) {
      appliedCredentials = true

      val ugi = UserGroupInformation.getCurrentUser
      ugi.addCredentials(credentials)
      // specify that this is a proxy user
      ugi.setAuthenticationMethod(AuthenticationMethod.PROXY)

      ugi.addCredentials(credentialsConf.value.value)
    }
  }

  /**
    * A simple abstraction over the HBaseContext.foreachPartition method.
    *
    * It lets a user turn each element of an RDD into a Delete and send it to
    * HBase in batches. The complexity of managing the Connection is removed
    * from the developer.
    *
    * @param rdd       Original RDD with data to iterate over
    * @param tableName The name of the table to delete from
    * @param f         Function to convert a value in the RDD to a HBase Delete
    * @param batchSize The number of deletes to batch before sending to HBase
    */
  def bulkDelete[T](rdd: RDD[T], tableName: TableName,
                    f: (T) => Delete, batchSize: Integer): Unit = {
    bulkMutation(rdd, tableName, f, batchSize)
  }

  /**
    * Underlying function supporting all bulk mutations.
    *
    * Mutations are accumulated per partition and sent with `Table.batch`
    * whenever `batchSize` is reached, plus one final batch for the remainder.
    *
    * May be opened up if requested.
    */
  private def bulkMutation[T](rdd: RDD[T], tableName: TableName,
                              f: (T) => Mutation, batchSize: Integer): Unit = {
    val tName = tableName.getName
    rdd.foreachPartition(
      it => hbaseForeachPartition[T](
        broadcastedConf,
        it,
        (iterator, connection) => {
          val table = connection.getTable(TableName.valueOf(tName))
          try {
            val mutationList = new java.util.ArrayList[Mutation]
            iterator.foreach { record =>
              mutationList.add(f(record))
              if (mutationList.size >= batchSize) {
                table.batch(mutationList, null)
                mutationList.clear()
              }
            }
            if (!mutationList.isEmpty) {
              table.batch(mutationList, null)
              mutationList.clear()
            }
          } finally {
            table.close()
          }
        }))
  }

  /**
    * A simple abstraction over the HBaseContext.mapPartitions method.
    *
    * It lets a user generate a new RDD based on Gets and the results they
    * bring back from HBase.
    *
    * @param tableName     The name of the table to get from
    * @param batchSize     The number of Gets sent to HBase per round trip
    * @param rdd           Original RDD with data to iterate over
    * @param makeGet       Function to convert a value in the RDD to a HBase Get
    * @param convertResult Converts each HBase Result object into whatever the
    *                      user wants to put in the resulting RDD
    * @return New RDD that is created by the Gets to HBase
    */
  def bulkGet[T, U: ClassTag](tableName: TableName,
                              batchSize: Integer,
                              rdd: RDD[T],
                              makeGet: (T) => Get,
                              convertResult: (Result) => U): RDD[U] = {

    val getMapPartition = new GetMapPartition(tableName,
      batchSize,
      makeGet,
      convertResult)

    rdd.mapPartitions[U](it =>
      hbaseMapPartition[T, U](
        broadcastedConf,
        it,
        getMapPartition.run))
  }

  /**
    * This function uses the native HBase TableInputFormat with the
    * given scan object to generate a new RDD.
    *
    * @param tableName the name of the table to scan
    * @param scan      the HBase scan object to use to read data from HBase
    * @param f         function to convert a Result object from HBase into
    *                  what the user wants in the final generated RDD
    * @return new RDD with results from scan
    */
  def hbaseRDD[U: ClassTag](tableName: TableName, scan: Scan,
                            f: ((ImmutableBytesWritable, Result)) => U): RDD[U] = {

    val scanJob: Job = Job.getInstance(getConf(broadcastedConf))

    TableMapReduceUtil.initCredentials(scanJob)
    TableMapReduceUtil.initTableMapperJob(tableName, scan,
      classOf[IdentityTableMapper], null, null, scanJob)

    // NOTE(review): jconf receives the current user's credentials, but it is not
    // handed to NewHBaseRDD below (scanJob.getConfiguration is). Confirm whether
    // that is intended.
    val jconf = new JobConf(scanJob.getConfiguration)
    SparkHadoopUtil.get.addCredentials(jconf)
    new NewHBaseRDD(sc,
      classOf[TableInputFormat],
      classOf[ImmutableBytesWritable],
      classOf[Result],
      scanJob.getConfiguration,
      this).map(f)
  }

  /**
    * An overloaded version of HBaseContext hbaseRDD that returns the raw
    * (row key, Result) pairs produced by the scan.
    *
    * @param tableName the name of the table to scan
    * @param scans     the HBase scan object to use to read data from HBase
    * @return New RDD with results from scan
    */
  def hbaseRDD(tableName: TableName, scans: Scan):
  RDD[(ImmutableBytesWritable, Result)] = {

    hbaseRDD[(ImmutableBytesWritable, Result)](
      tableName,
      scans,
      (r: (ImmutableBytesWritable, Result)) => r)
  }

  /**
    * Underlying wrapper for all foreach functions in HBaseContext.
    *
    * Resolves the configuration, applies credentials, borrows a connection
    * from HBaseConnectionCache and always releases it, even if `f` throws.
    */
  private def hbaseForeachPartition[T](configBroadcast: Broadcast[SerializableWritable[Configuration]],
                                       it: Iterator[T],
                                       f: (Iterator[T], Connection) => Unit): Unit = {
    val conf = getConf(configBroadcast)

    applyCreds()
    val smartConn = HBaseConnectionCache.getConnection(conf)
    try {
      f(it, smartConn.connection)
    } finally {
      smartConn.close()
    }
  }

  /**
    * Returns the configuration to use in this JVM.
    *
    * Lookup order:
    *  1. the cached `tmpHdfsConfiguration`;
    *  2. the file at `tmpHdfsConfgFile` on HDFS, if one was given;
    *  3. the broadcast configuration.
    *
    * A failure reading the broadcast is logged, and null is returned in that case.
    */
  private def getConf(configBroadcast: Broadcast[SerializableWritable[Configuration]]):
  Configuration = {

    if (tmpHdfsConfiguration == null && tmpHdfsConfgFile != null) {
      // newInstance returns an uncached FileSystem, so close it when done.
      val fs = FileSystem.newInstance(SparkHadoopUtil.get.conf)
      try {
        val inputStream = fs.open(new Path(tmpHdfsConfgFile))
        try {
          val loaded = new Configuration(false)
          loaded.readFields(inputStream)
          // Cache only a fully read configuration.
          tmpHdfsConfiguration = loaded
        } finally {
          inputStream.close()
        }
      } finally {
        fs.close()
      }
    }

    if (tmpHdfsConfiguration == null) {
      try {
        tmpHdfsConfiguration = configBroadcast.value.value
      } catch {
        case ex: Exception => logError("Unable to getConfig from broadcast", ex)
      }
    }
    tmpHdfsConfiguration
  }

  /**
    * Underlying wrapper for all mapPartitions functions in HBaseContext.
    *
    * The borrowed connection is released as soon as `mp` returns (or throws).
    * An `mp` that returns a lazy iterator still using the connection relies on
    * HBaseConnectionCache keeping it alive. NOTE(review): confirm the cache
    * semantics; GetMapPartition avoids the issue by materializing its results.
    */
  private def hbaseMapPartition[K, U](
                                       configBroadcast:
                                       Broadcast[SerializableWritable[Configuration]],
                                       it: Iterator[K],
                                       mp: (Iterator[K], Connection) =>
                                         Iterator[U]): Iterator[U] = {

    val conf = getConf(configBroadcast)
    applyCreds()

    val smartConn = HBaseConnectionCache.getConnection(conf)
    try {
      mp(it, smartConn.connection)
    } finally {
      smartConn.close()
    }
  }

  /**
    * Underlying wrapper for all get mapPartitions functions in HBaseContext.
    *
    * Gets are sent in batches of `batchSize`. Results are converted and collected
    * eagerly, so the Table can be closed before the returned iterator is used.
    */
  private class GetMapPartition[T, U](tableName: TableName,
                                      batchSize: Integer,
                                      makeGet: (T) => Get,
                                      convertResult: (Result) => U)
    extends Serializable {

    val tName = tableName.getName

    def run(iterator: Iterator[T], connection: Connection): Iterator[U] = {
      val table = connection.getTable(TableName.valueOf(tName))
      try {
        val gets = new java.util.ArrayList[Get]()
        // ListBuffer appends are O(1); the previous List ++ was O(n) per batch.
        val res = new ListBuffer[U]

        def sendBatch(): Unit = {
          if (!gets.isEmpty) {
            table.get(gets).foreach(r => res += convertResult(r))
            gets.clear()
          }
        }

        while (iterator.hasNext) {
          gets.add(makeGet(iterator.next()))
          if (gets.size() == batchSize) {
            sendBatch()
          }
        }
        sendBatch()
        res.iterator
      } finally {
        table.close()
      }
    }
  }

  /**
    * Writes `rdd` as HFiles into a temporary directory under /tmp and
    * bulk-loads them into `tableName` with LoadIncrementalHFiles.
    *
    * The connection, table and admin opened here are always closed. The
    * temporary directory and the TotalOrderPartitioner partition file are
    * registered with `deleteOnExit`.
    *
    * @param rdd           (row key, KeyValue) pairs, expected to be sorted as
    *                      HFileOutputFormat2 requires
    * @param tableName     target table
    * @param regionLocator region locator of the target table
    */
  def saveAsHFile(rdd: RDD[(ImmutableBytesWritable, KeyValue)],
                  tableName: TableName,
                  regionLocator: RegionLocator
                 ) = {
    val hfileJob = Job.getInstance(config, this.getClass.getName.split('$')(0))
    val connection = ConnectionFactory.createConnection(config)
    try {
      val table = connection.getTable(tableName)
      try {
        HFileOutputFormat2.configureIncrementalLoad(hfileJob, table, regionLocator)

        // Prepare the path for HFiles output. makeQualified returns a new Path,
        // so its result has to be kept.
        val fs = FileSystem.get(config)
        val hFilePath = fs.makeQualified(
          new Path("/tmp", table.getName.getQualifierAsString + "_" + UUID.randomUUID()))

        try {
          rdd
            .saveAsNewAPIHadoopFile(hFilePath.toString,
              classOf[ImmutableBytesWritable], classOf[KeyValue],
              classOf[HFileOutputFormat2],
              hfileJob.getConfiguration)

          // Prepare HFiles for incremental load: read/write/exec for all.
          setRecursivePermission(fs, hFilePath, new FsPermission("777"))

          val lih = new LoadIncrementalHFiles(config)
          // The deprecated doBulkLoad(Path, HTable) still exists in hbase 1.0.0;
          // the Admin/Table/RegionLocator variant is available since hbase 1.1.x.
          val admin = connection.getAdmin
          try {
            lih.doBulkLoad(hFilePath, admin, table, regionLocator)
          } finally {
            admin.close()
          }
        } finally {
          fs.deleteOnExit(hFilePath)

          // clean HFileOutputFormat2 stuff
          fs.deleteOnExit(new Path(TotalOrderPartitioner.getPartitionFile(hfileJob.getConfiguration)))
        }
      } finally {
        table.close()
      }
    } finally {
      connection.close()
    }
  }

  /**
    * Recursively applies `perm` to every entry below `path`.
    *
    * For every sub-directory not named "_tmp", a "_tmp" child directory is
    * created first. HFile splitting can then use that folder with correct
    * permissions. This is a workaround for unsecured HBase; it should not be
    * necessary with SecureBulkLoadEndpoint (see HBASE-8495).
    */
  private def setRecursivePermission(fs: FileSystem, path: Path, perm: FsPermission): Unit = {
    fs.listStatus(path).foreach { status =>
      val p = status.getPath
      fs.setPermission(p, perm)
      if (status.isDirectory && p.getName != "_tmp") {
        FileSystem.mkdirs(fs, new Path(p, "_tmp"), perm)
        setRecursivePermission(fs, p, perm)
      }
    }
  }

  /**
    * Looks up the region start keys of `tableName` on the driver.
    *
    * The cached connection and the RegionLocator are released afterwards.
    * An empty result is logged as "table not found".
    */
  private def regionStartKeys(tableName: TableName): Array[Array[Byte]] = {
    val conn = HBaseConnectionCache.getConnection(config)
    try {
      val regionLocator = conn.getRegionLocator(tableName)
      try {
        val startKeys = regionLocator.getStartKeys
        if (startKeys.length == 0) {
          logInfo("Table " + tableName.toString + " was not found")
        }
        startKeys
      } finally {
        regionLocator.close()
      }
    } finally {
      conn.close()
    }
  }

  /**
    * Like [[bulkGet]], but first repartitions the input with a BulkPartitioner
    * built from the table's region start keys, sorting by row key within each
    * partition, before the Gets are issued.
    *
    * NOTE(review): `makeGet` is currently NOT used. Only the row key (the first
    * element returned by `makeRowKey`) is shuffled, and a plain `new Get(rowKey)`
    * is issued for it. Any columns or filters `makeGet` would set are therefore
    * ignored. The second element of `makeRowKey`'s result is also unused.
    *
    * @param tableName     table to read from
    * @param batchSize     number of Gets sent to HBase per round trip
    * @param rdd           input records
    * @param makeRowKey    extracts the row key used for partitioning and sorting
    * @param makeGet       currently ignored (see note above)
    * @param convertResult converts each HBase Result into an output element
    * @return RDD of converted results
    */
  def BulkPartitionedSortedGet[T, U: ClassTag](tableName: TableName,
                                               batchSize: Integer,
                                               rdd: RDD[T],
                                               makeRowKey: (T) => (ByteArrayWrapper, ByteArrayWrapper),
                                               makeGet: (T) => Get,
                                               convertResult: (Result) => U): RDD[U] = {
    // 10 is BulkPartitioner's second argument (bulkLoadThinRows passes
    // numFilesPerRegionPerFamily there). NOTE(review): confirm the intended value.
    val regionPartitioner = new BulkPartitioner(regionStartKeys(tableName), 10)

    val sortRDD = rdd.map(r => makeRowKey(r)).keyBy(_._1)
      .repartitionAndSortWithinPartitions(regionPartitioner).map(_._1)

    bulkGet[ByteArrayWrapper, U](tableName, batchSize, sortRDD,
      rowKey => new Get(rowKey.value), convertResult)

    // FIXME: 1. Is this approach viable? 2. Can it cause data skew?
    // 3. Is the implementation reasonable?
  }

  /**
    * Bulk-loads "thin" rows (each given as a row key plus its family ->
    * qualifier -> value map) into a table by writing HFiles and loading them.
    *
    * Ideal use case: larger batches with a table that spans many Region Servers.
    * Anti-pattern: not the best for huge bulk loads, and may be overkill for
    * simple, smaller bulk puts/mutations.
    *
    * @param rdd                        input records
    * @param tableName                  name of the target table
    * @param mapFunction                converts a record into its row key and values
    * @param numFilesPerRegionPerFamily passed to BulkPartitioner together with the
    *                                   region start keys
    * @tparam T type of the input records
    */
  def bulkLoadThinRows[T](rdd: RDD[T],
                          tableName: String,
                          mapFunction: (T) => (ByteArrayWrapper, FamiliesQualifiersValues),
                          numFilesPerRegionPerFamily: Int
                         ) = {
    val conn = HBaseConnectionCache.getConnection(config)
    try {
      val tn = TableName.valueOf(tableName)
      // Kept open until saveAsHFile has finished, since it is passed along.
      val regionLocator = conn.getRegionLocator(tn)
      try {
        val startKeys = regionLocator.getStartKeys
        if (startKeys.length == 0) {
          logInfo("Table " + tableName + " was not found")
        }
        val regionSplitPartitioner = new BulkPartitioner(startKeys, numFilesPerRegionPerFamily)

        val sortRDD = rdd.map(r => mapFunction(r))
          .repartitionAndSortWithinPartitions(regionSplitPartitioner)
          .flatMap { case (rowKeyWrapper, familiesQualifiersValues) =>
            val rowKey = rowKeyWrapper.value
            // Emit the row key as the record key instead of an empty writable.
            val rowKeyWritable = new ImmutableBytesWritable(rowKey)
            val kvs = new ListBuffer[(ImmutableBytesWritable, KeyValue)]

            val familyIt = familiesQualifiersValues.familyMap.entrySet().iterator()
            while (familyIt.hasNext) {
              val familyEntry = familyIt.next()
              val family = familyEntry.getKey.value
              val qualifierIt = familyEntry.getValue.entrySet().iterator()

              // The original code relies on the qualifier map being a sorted
              // (tree) map, so qualifiers come out in order.
              // NOTE(review): confirm in FamiliesQualifiersValues.
              while (qualifierIt.hasNext) {
                val qualifierEntry = qualifierIt.next()
                val qualifier = qualifierEntry.getKey.value
                val cellValue = qualifierEntry.getValue
                kvs += ((rowKeyWritable, new KeyValue(rowKey, family, qualifier, cellValue)))
              }
            }
            kvs
          }
        saveAsHFile(sortRDD, tn, regionLocator)
      } finally {
        regionLocator.close()
      }
    } finally {
      // Release the reference taken from HBaseConnectionCache.
      conn.close()
    }
  }

  /**
    * Produces a ClassTag[T], which is actually just a casted ClassTag[AnyRef].
    *
    * This method is used to keep ClassTags out of the external Java API, as
    * the Java compiler cannot produce them automatically. While this
    * ClassTag-faking does please the compiler, it can cause problems at runtime
    * if the Scala API relies on ClassTags for correctness.
    *
    * Often, though, a ClassTag[AnyRef] will not lead to incorrect behavior,
    * just worse performance or security issues.
    * For instance, an Array of AnyRef can hold any type T, but may lose primitive
    * specialization.
    */
  private
  def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
}



/**
  * JVM-wide holder for the most recently constructed [[HBaseContext]].
  *
  * The HBaseContext constructor assigns itself here.
  */
object LatestHBaseContextCache {
  /** Most recently constructed HBaseContext in this JVM; null until one is created. */
  var latest: HBaseContext = _
}